{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Example of generating data for training/evaluation\n",
    "---\n",
    "\n",
    "\n",
    "## Table of Contents\n",
    "\n",
    "1. [Data generator for training CFO Correction Network](#cfo_data_generator)\n",
    "2. [Data generator for training Equalization](#equalization)\n",
    "3. [Data generator for training Demod + Decoder](#decoder)\n",
    "4. [Data generator for training End to End](#end2end)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Environment setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# Import radioml package from other directory\n",
    "module_path = os.path.abspath(os.path.join('..'))\n",
    "if module_path not in sys.path:\n",
    "    sys.path.append(module_path)    \n",
    "\n",
    "# For visualization\n",
    "import imageio\n",
    "from IPython.display import Image, display\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import pylab\n",
    "sns.set_style(\"darkgrid\")\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_PACKETS  = 5\n",
    "PREAMBLE_LEN = 40\n",
    "DATA_LEN     = 200\n",
    "CHANNEL_LEN = 1\n",
    "OMEGA      = 1/50\n",
    "SNR_train  = 10.0 "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualization Funcs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.ioff()\n",
    "\n",
    "\n",
    "def visualize_signals(ax, x, y, groundtruths=None, title=None, cmap='Set1', min_val=-2, max_val=2):\n",
    "    ax.scatter(x, y, c=groundtruths, s=100, cmap=cmap)\n",
    "    ax.set_xlabel('I-component')\n",
    "    ax.set_ylabel('Q-component')\n",
    "    ax.set_title(title)\n",
    "    ax.axhline()\n",
    "    ax.axvline()\n",
    "#     ax.set_xlim(min_val, max_val)\n",
    "#     ax.set_ylim(min_val, max_val)\n",
    "\n",
    "def visualize_cfo(radio, omega, snr):\n",
    "    \"\"\"Visualize the effect of Carrier Offset Frequency.\"\"\"\n",
    "    generator = cfo_net_data_generator(radio, omega, snr, batch_size=4)\n",
    "\n",
    "    [preambles, preambles_conv], cfo_corrected_preamble = next(generator)\n",
    "    symbols, groundtruths = np.unique(preambles.view(np.complex),return_inverse=True)\n",
    "\n",
    "    fig, (left_ax, right_ax) = plt.subplots(1, 2, figsize=(10, 5))\n",
    "    st = fig.suptitle(\"Omega = %.3f || SNR_dB = %.2f \" % (omega, snr))\n",
    "    visualize_signals(left_ax, \n",
    "                     x=preambles_conv[...,0].flatten(), \n",
    "                     y=preambles_conv[...,1].flatten(),\n",
    "                     groundtruths=groundtruths, \n",
    "                     title=\"Before CFO Correction\", \n",
    "                     cmap='Set1')\n",
    "    visualize_signals(right_ax, \n",
    "                     x=cfo_corrected_preamble[...,0].flatten(), \n",
    "                     y=cfo_corrected_preamble[...,1].flatten(),\n",
    "                     groundtruths=groundtruths, \n",
    "                     title=\"After CFO Correction\")\n",
    "    img = _convert_fig_to_img(fig)\n",
    "    return img\n",
    "    \n",
    "\n",
    "def _convert_fig_to_img(fig):\n",
    "    \"\"\"Convert figure to image.\"\"\"\n",
    "    # Used to return the plot as an image rray\n",
    "    fig.canvas.draw()       # draw the canvas, cache the renderer\n",
    "    image = np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8')\n",
    "    image  = image.reshape(fig.canvas.get_width_height()[::-1] + (3,))\n",
    "    return image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example of data generator for training CFO Correction  <a class=\"anchor\" id=\"cfo_data_generator\"></a>\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "from radioml.core import RadioTransmitter\n",
    "from radioml.dataset import cfo_net_data_generator\n",
    "\n",
    "channel_len = 1\n",
    "radio_transmitter = RadioTransmitter(DATA_LEN, PREAMBLE_LEN,'QPSK')\n",
    "generator = cfo_net_data_generator(radio_transmitter, OMEGA, SNR_train, batch_size=4)\n",
    "\n",
    "for i in range(3):\n",
    "    one_batch = next(generator)\n",
    "    [preambles, convolved_preambles], cfo_corrected_preambles = one_batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "fnames = []\n",
    "\n",
    "radios = [RadioTransmitter(0, PREAMBLE_LEN, mod) for mod in  ['QPSK', 'QAM16']]\n",
    "\n",
    "# Save and Display the animated result \n",
    "filename = './figures/cfo_%s.gif' % 'qam16'\n",
    "imageio.mimsave(filename, \n",
    "                [visualize_cfo(radio, omega, 30.0) for radio in radios for omega in [1/100., 1/50, 1/30]], \n",
    "                fps=0.5)\n",
    "    \n",
    "with open(filename,'rb') as f:\n",
    "    display(Image(data=f.read(), format='png'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example of data generatorfor training Equalization  <a class=\"anchor\" id=\"equalization\"></a>\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def visualize_equalization(radio, channel_len, snr):\n",
    "    \"\"\"Visualize the effect of Symbol Inteference.\"\"\"\n",
    "    \n",
    "    examples = next(equalization_data_generator(radio, \n",
    "                                                channel_len,\n",
    "                                                snr, \n",
    "                                                batch_size=1, \n",
    "                                                num_cpus=8))\n",
    "    \n",
    "    [preambles, _, cfo_corrected_data], \\\n",
    "    [equalized_data, modulated_packet] = examples   \n",
    "    \n",
    "    modulated_packet = modulated_packet[:, radio.modulated_preamble_len:]\n",
    "    symbols, groundtruths = np.unique(modulated_packet.view(np.complex),return_inverse=True)\n",
    "    fig, (left_ax, right_ax) = plt.subplots(1, 2, figsize=(10,5))\n",
    "    st = fig.suptitle(\"Number of Channels = %d\" % channel_len)\n",
    "    \n",
    "    visualize_signals(left_ax, \n",
    "                      cfo_corrected_data[...,0].flatten(),\n",
    "                      cfo_corrected_data[...,1].flatten(), \n",
    "                      groundtruths,\n",
    "                      title=\"Before Equalization\", \n",
    "                      min_val=-3,max_val=3, cmap='rainbow')\n",
    "    visualize_signals(right_ax, \n",
    "                      equalized_data[...,0].flatten(),\n",
    "                      equalized_data[...,1].flatten(),\n",
    "                      groundtruths,\n",
    "                      title=\"After Equalization\", \n",
    "                      min_val=-3, max_val=3)\n",
    "    img = _convert_fig_to_img(fig)\n",
    "    return img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "from radioml.core import RadioTransmitter\n",
    "from radioml.dataset import equalization_data_generator\n",
    "\n",
    "radio = RadioTransmitter(DATA_LEN,PREAMBLE_LEN, modulation_scheme='QPSK')\n",
    "generator = equalization_data_generator(radio, CHANNEL_LEN, SNR_train, batch_size=32)\n",
    "\n",
    "for i in range(3):\n",
    "    one_batch = next(generator)\n",
    "    [preambles, cfo_corrected_preamble, cfo_corrected_preamble], channels_etimates = one_batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Visualize equalization data\n",
    "fnames=[]\n",
    "radio = RadioTransmitter(DATA_LEN,PREAMBLE_LEN, modulation_scheme='QPSK')\n",
    "\n",
    "CHANNEL_TAPS = [2]\n",
    "for modulation_scheme in ['QPSK']:\n",
    "    filename = './figures/channel_inference_%s.gif'%modulation_scheme\n",
    "    imageio.mimsave(filename, \n",
    "                    [visualize_equalization(radio, channel_tap, 50.0) \n",
    "                      for channel_tap in CHANNEL_TAPS], fps=2)\n",
    "    fnames.append(filename)\n",
    "\n",
    "with open(filename,'rb') as f:\n",
    "    display(Image(data=f.read(), format='png'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data generator for training Demod + Decoder <a class=\"anchor\" id=\"decoder\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def visualize_demod_n_decode(radio, snr):\n",
    "    [equalized_data, modulated_data], \\\n",
    "    data_estimate = next(ecc_data_generator(radio, snr, batch_size=4))\n",
    "    \n",
    "    symbols, groundtruths = np.unique(modulated_data.view(np.complex), return_inverse=True)\n",
    "    fig, ax = plt.subplots(1, 1)\n",
    "    \n",
    "    visualize_signals(ax, \n",
    "                      equalized_data[...,0].flatten(),\n",
    "                      equalized_data[...,1].flatten(), \n",
    "                      groundtruths,\n",
    "                      title=\"Equalized singals at $SNR_{dB}$ = %.2f dB\"%snr)\n",
    "\n",
    "    img = _convert_fig_to_img(fig)\n",
    "    return img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 2.61 s, sys: 252 ms, total: 2.86 s\n",
      "Wall time: 3.56 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "from radioml.core import RadioTransmitter\n",
    "from radioml.dataset import ecc_data_generator\n",
    "\n",
    "radio = RadioTransmitter(DATA_LEN,PREAMBLE_LEN, modulation_scheme='QPSK')\n",
    "generator = ecc_data_generator(radio, SNR_train, batch_size=32)\n",
    "\n",
    "for i in range(3):\n",
    "    one_batch = next(generator)\n",
    "    equalized_data, data_estimate = one_batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<img src=\"./figures/error_correction.gif\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "SNRs = [13.0, 15.0, 17.0, 20.0, 30.0]\n",
    "\n",
    "radio = RadioTransmitter(DATA_LEN,PREAMBLE_LEN, modulation_scheme='QPSK')\n",
    "imageio.mimsave('./figures/error_correction.gif', \n",
    "                   [visualize_demod_n_decode(radio, snr) for snr in SNRs], \n",
    "                   fps=1)\n",
    "display(Image(url='./figures/error_correction.gif'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "radio = RadioDataGenerator(DATA_LEN,PREAMBLE_LEN, CHANNEL_LEN, modulation_scheme='QPSK')\n",
    "generator = radio.end2end_data_generator(OMEGA, SNR_train, batch_size=32)\n",
    "\n",
    "for i in range(3):\n",
    "    one_batch = next(generator)\n",
    "    [preambles, corrupted_packets], data_etimates = one_batch"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Radio ML",
   "language": "python",
   "name": "radioml"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
